import math
import random
import numpy as np
import util
import csv
import os
import torch.utils.data as Data
from torch.utils.data import Dataset
import torch
# from torch import nn as nn
# from torch import Tensor
# import torch.optim as optim
# from tqdm import tqdm
# from core.util import get_project_root, latest_checkpoint, softmax, segment
# from sklearn.model_selection import train_test_split
# from net2d.loss_function import LGMLoss_v0
# from torchinfo import summary
# from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoderLayer,TransformerDecoder
# from torch.utils.tensorboard import SummaryWriter
# from abc import ABCMeta, abstractmethod
# from sklearn.metrics import classification_report

# Maps per-frame BIO tags to class ids. 'O' (background/other) is 1; actions
# 01-11 get ids 2-12, and actions 12-22 wrap around to reuse the same ids
# 2-12 (i.e. id = ((action - 1) % 11) + 2). The B(egin) and I(nside) variants
# of one action share a single id.
# NOTE(review): the 12-22 -> 2-12 reuse looks deliberate (two rounds of the
# same 11 actions), but worth confirming against the annotation spec.
class_dict = {'O': 1}
for _action in range(1, 23):
    _cls = (_action - 1) % 11 + 2
    for _suffix in ('B', 'I'):
        class_dict[f"{_action:02d}-{_suffix}"] = _cls
del _action, _cls, _suffix

# Absolute project root, resolved once at import time by the project helper.
PROJECT_ROOT = util.get_project_root()

def five_grade_transform(num_class, grade):
    """Collapse a five-grade label onto a coarser class scale.

    :param num_class: target number of classes — 2, 3, or 5.
    :param grade: original grade (0-4; 4 presumably the top grade — confirm).
    :return: the remapped grade:
        - 2 classes: 0 -> 0, anything positive -> 1
        - 3 classes: 0 -> 0, 1-3 -> 1, 4 -> 2
        - 5 classes: returned unchanged
    :raises ValueError: if num_class is not one of 2, 3, 5.
    """
    if num_class == 2:
        return 1 if grade > 0 else 0
    if num_class == 3:
        if grade == 4:
            return 2
        return 1 if grade > 0 else 0
    if num_class == 5:
        return grade
    # Was `assert False, "Grade Error"` — asserts are stripped under
    # `python -O`, which would make invalid input silently return None.
    raise ValueError("Grade Error")

class Data_Toolbox():
    """Prepares the training set: joins per-sample feature files (.npz) with
    CSV score and per-frame BIO annotations, augments the sample list, and
    saves everything as a single .npz archive.
    """

    def __init__(self, num_classes,):
        super(Data_Toolbox, self).__init__()
        # Target number of grade classes (2, 3, or 5 — see five_grade_transform).
        self.num_classes = num_classes
        # Feature directory; set by generate_trainning_dataset, reused by data_augment.
        self.feature_path = None

    # Read a per-frame annotation CSV into a list of class ids.
    @staticmethod
    def read_seq_annotation(filename):
        """Return the per-frame class ids of one sequence.

        :param filename: CSV file with one (frame index, BIO label) row per frame.
        :return: list of int class ids (labels mapped through class_dict).
        """
        annotation = []
        with open(filename, 'r') as f:
            reader = csv.reader(f)
            for index, label in reader:
                annotation.append(class_dict[label])
        return annotation

    def generate_trainning_dataset(self, feature_path, scocer_path, fps):
        """Assemble the raw (un-augmented) sample list.

        :param feature_path: directory of per-sample .npz files holding
            "x" (frame-feature matrix) and "v" (the sample's file code).
        :param scocer_path: CSV of (index, file_code, grade, score1, score2) rows.
        :param fps: frame rate; selects the BIO annotation directory.
        :return: list of sample dicts with keys grade_label, seq_label,
            score_label, fcode, feature_file, feature_index (None for originals).
        """
        num_classes = self.num_classes
        seq_root_path = "{}/data/74_fps{}/BIO".format(PROJECT_ROOT, fps)

        annotation = {}
        # Load the grade/score data, keyed by file code.
        with open(scocer_path, 'r') as f:
            reader = csv.reader(f)
            for index, file_code, label, s1, s2 in reader:
                file_code = file_code.strip()
                annotation[file_code] = {'grade_label': int(label), 'scores': (int(s1), int(s2)), 'seq_label': None}

        # Load the per-frame (single-image) classification sequence of each sample.
        for index, item in annotation.items():
            seq_file = "{}/{}.csv".format(seq_root_path, index)
            seq_label = Data_Toolbox.read_seq_annotation(seq_file)
            annotation[index]['seq_label'] = np.array(seq_label, )

        # Load the features and join everything into one record per sample.
        self.feature_path = feature_path
        dataset = []
        for feature_file in os.listdir(feature_path):
            if os.path.splitext(feature_file)[-1] == ".npz":
                filename = os.path.join(feature_path, feature_file)
                data = np.load(filename, allow_pickle=True)
                feature, fcode = data["x"], data["v"]
                fcode = str(fcode)
                grade_label = annotation[fcode]['grade_label']
                grade_label = five_grade_transform(num_classes, grade_label)
                seq_label = annotation[fcode]['seq_label']
                score_label = annotation[fcode]['scores']

                # Fewer feature frames than labels: drop the trailing labels.
                # NOTE(review): the opposite case (more frames than labels) is
                # left as-is here, yet data_augment asserts equal lengths.
                sample_len = np.size(feature, axis=0)
                seq_size = len(seq_label)
                if sample_len < seq_size:
                    seq_label = seq_label[:sample_len]
                dataset.append({'grade_label': grade_label, 'seq_label': seq_label, 'score_label': score_label,
                                'fcode': fcode, 'feature_file': feature_file, 'feature_index':None})

        return dataset

    def data_augment(self, train_data, p_scale, n_scale):
        """Append augmented samples to train_data (in place) and return it.

        Positive augmentation: randomly drops ~2-3% of a sample's frames
        (recorded as a boolean keep-mask in 'feature_index'); labels kept.
        Negative augmentation: keeps only one short random segment (~12.5-16.5%
        of the frames) and relabels the sample grade 0 with scores (0, 0).

        :param train_data: sample list from generate_trainning_dataset.
        :param p_scale: positive samples to add, as a multiple of len(train_data).
        :param n_scale: negative samples to add, as a multiple of len(train_data).
        :return: train_data, extended with the new samples.
        """
        n = len(train_data)
        m = int(n * p_scale)
        # Positive-sample augmentation.
        new_data = []
        for i in range(m):
            sparse = 0.02 + 0.01 * random.random()  # per-frame drop probability
            sid = random.randint(0, n - 1)
            sample = train_data[sid]

            grade_label = sample["grade_label"]
            seq_label = sample["seq_label"]
            score_label = sample["score_label"]
            fcode = sample["fcode"]
            feature_file = sample["feature_file"]

            # Reload the features only to get the frame count and re-check pairing.
            filename = os.path.join(self.feature_path, feature_file)
            data = np.load(filename, allow_pickle=True)
            feature, _fcode = data["x"], data["v"]

            assert fcode == str(_fcode), "数据加载不对应"

            sample_len = np.size(feature, axis=0)
            seq_size = len(seq_label)
            # Boolean keep-mask: True keeps a frame, False drops it.
            feat_index = np.random.random(sample_len) >= sparse
            assert sample_len == seq_size, "特征与标注的长度不符"
            new_seq_label = seq_label[feat_index]
            new_data.append({'grade_label': grade_label, 'seq_label': new_seq_label, 'score_label': score_label,
                             'fcode': f"{fcode}_p", 'feature_file': feature_file, 'feature_index': feat_index})

        # Negative-sample augmentation.
        new2_data = []
        k = int(n * n_scale)
        for i in range(k):
            s = 0.125 + 0.04 * random.random()  # kept-segment length ratio
            sid = random.randint(0, n - 1)
            sample = train_data[sid]

            grade_label = 0
            seq_label = sample["seq_label"]
            score_label = (0, 0)
            fcode = sample["fcode"]
            feature_file = sample["feature_file"]

            filename = os.path.join(self.feature_path, feature_file)
            data = np.load(filename, allow_pickle=True)
            feature, _fcode = data["x"], data["v"]

            assert fcode == str(_fcode), "数据加载不对应"

            sample_len = np.size(feature, axis=0)
            seq_size = len(seq_label)
            seg_len = int(s * sample_len)
            # Pick a segment start at least 10 frames from either end.
            # NOTE(review): randint raises ValueError when
            # sample_len - seg_len - 10 <= 10 (sample shorter than ~25 frames)
            # — assumed not to occur with this data; confirm.
            start = np.random.randint(10, sample_len - seg_len - 10)
            feat_index = np.zeros((sample_len),dtype=np.bool_)
            feat_index[start:start + seg_len] = True
            new_seq_label = seq_label[start:start + seg_len]
            new2_data.append({'grade_label': grade_label, 'seq_label': new_seq_label, 'score_label': score_label,
                             'fcode': f"{fcode}_n", 'feature_file': feature_file, 'feature_index': feat_index})
        train_data.extend(new_data)
        train_data.extend(new2_data)
        return train_data

    def create_train_set(self, feature_path, scocer_path, dataset_code, fps):
        """Build and augment the sample list, then save it to
        {PROJECT_ROOT}/train_data/fps{fps}/{dataset_code}.npz (key "data").
        """
        dataset = self.generate_trainning_dataset(feature_path, scocer_path, fps)
        dataset = self.data_augment(dataset, p_scale=2.0, n_scale=1.0)
        print("len =",len(dataset))
        trian_data_path = f"{PROJECT_ROOT}/train_data/fps{fps}/{dataset_code}.npz"
        np.savez(trian_data_path, data=dataset)

class Sequence_Data_V1(Dataset):
    '''
        Sequence dataset built on pre-extracted per-frame image features.

        Each item is one video sample: the frame-feature matrix truncated to a
        whole number of bptt-length segments, its grade label, a per-window
        action-label sequence (majority vote over each `step`-frame window),
        the segment count, and the two judge scores scaled to [0, 1].
    '''
    def __init__(self, feature_path, data_set, bptt, step):
        '''
        Initialize the dataset.
        :param feature_path: directory holding the per-sample feature .npz files
        :param data_set: list of sample dicts as produced by Data_Toolbox
            (keys: grade_label, seq_label, score_label, feature_file, feature_index)
        :param bptt: truncated-BPTT segment length, in frames
        :param step: majority-vote window size; must evenly divide bptt
        '''
        super(Sequence_Data_V1, self).__init__()
        self.feature_path = feature_path
        self.data_set = data_set
        self.bptt = bptt
        self.step = step
        assert bptt % step ==0, f"bptt({bptt})不能整除step({step})"

    def __len__(self):
        return len(self.data_set)

    def __getitem__(self, i):
        bptt = self.bptt

        sample = self.data_set[i]

        grade_label = sample["grade_label"]
        seq_label = sample["seq_label"]
        score_label = sample["score_label"]
        feature_file = sample["feature_file"]
        feat_index = sample["feature_index"]

        filename = os.path.join(self.feature_path, feature_file)
        data = np.load(filename, allow_pickle=True)
        feature, _fcode = data["x"], data["v"]
        # Augmented samples carry a boolean frame mask; originals have None.
        # (The y/seq_y unpacking used to be duplicated in both branches.)
        X = feature[feat_index] if feat_index is not None else feature
        y, (s1, s2), seq_y = grade_label, score_label, seq_label

        # Keep only a whole number of bptt-length segments.
        length = len(seq_y)
        seg_count = length // bptt
        X = X[:seg_count * bptt]
        tX = torch.from_numpy(X)

        # Majority-vote the per-frame labels within each `step`-frame window.
        seq_y = seq_y[:seg_count * bptt]
        windows = np.reshape(seq_y, (-1, self.step))
        # int64 from the start — was a float zeros array cast via .long() later.
        voted = np.zeros(windows.shape[0], dtype=np.int64)
        # `w` instead of `i`: the old loop variable shadowed the item index.
        for w, window in enumerate(windows):
            voted[w] = np.argmax(np.bincount(window, minlength=24))  # 24 action classes
        return tX, y, torch.from_numpy(voted).long(), seg_count, torch.FloatTensor((s1 / 100, s2 / 100))

if __name__ == '__main__':
    # Build and save the augmented training set for the res2net50 features at 6 fps.
    dtool = Data_Toolbox(num_classes=3)
    feature_path = f"{PROJECT_ROOT}/train_data/fps6/res2net50"
    scocer_path = f"{PROJECT_ROOT}/data/score/score_two.csv"
    dtool.create_train_set(feature_path, scocer_path, dataset_code='res2net50_fps6_trans', fps=6)

    # Alternative configuration: resnet2p1d18 features at 12 fps.
    # feature_path = f"{PROJECT_ROOT}/train_data/fps12/resnet2p1d18"
    # scocer_path = f"{PROJECT_ROOT}/data/score/score_two.csv"
    # dtool.create_train_set(feature_path, scocer_path, dataset_code='resnet2p1d18_fps12_trans', fps=12)